# Script modified to extract connection-level features [Hayes et al. USENIX 2016]

import csv
import sys
import math
from sys import stdout
import numpy as np
import operator
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn import metrics
from sklearn import tree
import sklearn.metrics as skm
import scipy
import random
import os
from collections import defaultdict
import argparse
from itertools import chain
import subprocess

# re-seed the generator
#np.random.seed(1234)

#Note: dictionary_() will extract features; -1: INCOMING, 1: OUTGOING Tor Cell
#Cell file format: "topk#time direction(-/+)cellsize"

"""Feeder functions"""

def neighborhood(iterable):
    """Yield ``(previous, current, next)`` triples for each item of *iterable*.

    The first triple uses ``0`` as its "previous" value and the last triple
    uses ``None`` as its "next" value. An empty iterable yields nothing.
    """
    iterator = iter(iterable)
    try:
        item = next(iterator)
    except StopIteration:
        # A StopIteration escaping a generator body is converted into a
        # RuntimeError (PEP 479), so finish the generator explicitly instead.
        return
    prev = 0
    for nex in iterator:
        yield (prev, item, nex)
        prev = item
        item = nex
    yield (prev, item, None)

def chunkIt(seq, num):
    """Cut *seq* into consecutive slices of (fractional) width ``len(seq) / num``.

    A floating-point cursor is advanced by the average width and truncated to
    an int for the slice bounds, so slices can be empty when *num* exceeds
    ``len(seq)``. Returns the list of slices; an empty *seq* gives ``[]``.
    """
    total = len(seq)
    step = total / float(num)
    pieces = []
    cursor = 0.0
    while cursor < total:
        upper = cursor + step
        pieces.append(seq[int(cursor):int(upper)])
        cursor = upper
    return pieces

"""Non-feeder functions"""

def get_pkt_list(trace_data):
    """Parse trace lines into ``(relative_time, direction)`` tuples.

    Each packet line is tab-separated: field 0 is a timestamp and field 1 a
    signed number whose sign gives the direction (``> 0`` -> ``1``, anything
    else -> ``-1``). Lines containing ``"##HOST_FTS"`` are skipped. Times are
    made relative to the first packet line.

    Args:
        trace_data: sequence of text lines of one trace.

    Returns:
        List of ``(time - first_time, direction)`` tuples; empty when the
        trace has no packet lines.
    """
    # Drop host-feature lines up front so the reference timestamp is always
    # taken from a real packet line, even if a host-feature line comes first.
    packet_lines = [line for line in trace_data if "##HOST_FTS" not in line]
    if not packet_lines:
        return []

    first_time = float(packet_lines[0].rstrip().split("\t")[0])
    dta = []
    for line in packet_lines:
        fields = line.rstrip().split("\t")
        if "e-" in fields[0]:
            # Timestamps written with a negative exponent are treated as 0.0.
            print("Exponent in total seconds: ", fields)
            fields = [0.0, fields[1]]
        direction = 1 if float(fields[1]) > 0 else -1
        dta.append((float(fields[0]) - first_time, direction))
    return dta


def In_Out(list_data):
    """Split packets by direction flag.

    Returns ``(In, Out)`` where ``In`` holds the entries whose second field
    equals ``-1`` and ``Out`` those whose second field equals ``1``. Entries
    with any other value are dropped. Input order is preserved.
    """
    incoming, outgoing = [], []
    for pkt in list_data:
        direction = pkt[1]
        if direction == -1:
            incoming.append(pkt)
        elif direction == 1:
            outgoing.append(pkt)
    return incoming, outgoing

############### TIME FEATURES #####################

def inter_pkt_time(list_data):
    """Return the gaps between consecutive packet timestamps.

    Args:
        list_data: sequence of entries whose first field is a timestamp.

    Returns:
        List of ``next_time - time`` differences, one fewer than the number
        of packets; ``[]`` for zero or one packet.
    """
    times = [x[0] for x in list_data]
    # Pairing with the shifted list avoids indexing times[0], which failed
    # on an empty direction (e.g. a trace with no incoming packets).
    return [later - earlier for earlier, later in zip(times, times[1:])]

def interarrival_times(list_data):
    """Return inter-packet gaps for incoming, outgoing and all packets.

    The packet list is split with ``In_Out`` and each of the three series is
    passed through ``inter_pkt_time``. The result is ``(IN, OUT, TOTAL)``.
    """
    incoming, outgoing = In_Out(list_data)
    gaps_in = inter_pkt_time(incoming)
    gaps_out = inter_pkt_time(outgoing)
    return gaps_in, gaps_out, inter_pkt_time(list_data)

def interarrival_maxminmeansd_stats(list_data):
    """Summarise inter-arrival gaps for incoming, outgoing and all packets.

    Returns:
        A one-element list holding a 12-tuple laid out as
        ``(max_in, max_out, max_total, mean_in, mean_out, mean_total,
        std_in, std_out, std_total, p75_in, p75_out, p75_total)``.
        A series with no gaps contributes zeros. When neither direction has
        any gap the whole row is zeros.
    """
    In, Out, Total = interarrival_times(list_data)

    if not In and not Out:
        # Keep the same shape as the other cases (one 12-tuple) so callers
        # can always take element [0] and iterate over it.
        return [(0,) * 12]

    def _summary(gaps):
        # (max, mean, std, 75th percentile) of one gap series, zeros if empty.
        if not gaps:
            return (0, 0, 0, 0)
        return (
            max(gaps),
            sum(gaps) / float(len(gaps)),
            np.std(gaps),
            np.percentile(gaps, 75),
        )

    per_series = [_summary(In), _summary(Out), _summary(Total)]
    # Transpose so the row is grouped by statistic rather than by series.
    row = tuple(value for stat in zip(*per_series) for value in stat)
    return [row]

def time_percentile_stats(trace_data):
    """Return timestamp percentiles for incoming, outgoing and all packets.

    For each of the three series (in that order) the 25th, 50th, 75th and
    100th percentile of the relative timestamps is appended. A series with
    no packets contributes four zeros instead. The result has 12 entries.
    """
    packets = get_pkt_list(trace_data)
    incoming, outgoing = In_Out(packets)
    quartiles = (25, 50, 75, 100)
    stats = []
    for group in (incoming, outgoing, packets):
        stamps = [pkt[0] for pkt in group]
        if stamps:
            stats.extend(np.percentile(stamps, q) for q in quartiles)
        else:
            stats.extend([0] * 4)
    return stats

def number_pkt_stats(trace_data):
    """Return ``(n_incoming, n_outgoing, n_total)`` packet counts of a trace."""
    packets = get_pkt_list(trace_data)
    incoming, outgoing = In_Out(packets)
    counts = (len(incoming), len(outgoing), len(packets))
    return counts

def first_and_last_30_pkts_stats(trace_data):
    """Count directions within the first and the last 30 packets.

    Returns ``[first30_in, first30_out, last30_in, last30_out]`` where "in"
    counts direction ``-1`` and "out" counts direction ``1``.
    """
    packets = get_pkt_list(trace_data)

    def _direction_counts(window):
        # [number of -1 packets, number of +1 packets] inside the window.
        incoming = sum(1 for pkt in window if pkt[1] == -1)
        outgoing = sum(1 for pkt in window if pkt[1] == 1)
        return [incoming, outgoing]

    return _direction_counts(packets[:30]) + _direction_counts(packets[-30:])

#concentration of outgoing packets in chunks of 20 packets
def pkt_concentration_stats(trace_data):
    """Summarise how many direction-``1`` packets fall in each 20-packet window.

    Returns ``(std, mean, median, min, max, counts)`` where ``counts`` is the
    per-window list of outgoing-packet counts, in trace order.
    """
    packets = get_pkt_list(trace_data)
    concentrations = []
    for start in range(0, len(packets), 20):
        window = packets[start:start + 20]
        concentrations.append(sum(1 for pkt in window if pkt[1] == 1))

    spread = np.std(concentrations)
    mean_conc = sum(concentrations) / float(len(concentrations))
    median_conc = np.percentile(concentrations, 50)
    return (
        spread,
        mean_conc,
        median_conc,
        min(concentrations),
        max(concentrations),
        concentrations,
    )

#Average number packets sent and received per second
def number_per_sec(trace_data):
    """Count packets per one-second bucket and summarise the counts.

    The number of buckets is the ceiling of the last packet's relative time,
    with a minimum of one so that a trace spanning zero seconds still gets a
    bucket. Bucket ``k`` counts packets with ``k < t <= k + 1``. Bucket 0
    also takes every packet with ``t <= 1``. Packets later than the last
    bucket are not counted.

    Returns:
        ``(mean, std, median, min, max, counts)``; all zeros and an empty
        list when the trace has no packets.
    """
    Total = get_pkt_list(trace_data)
    if not Total:
        return 0, 0, 0, 0, 0, []

    n_seconds = max(int(math.ceil(Total[-1][0])), 1)
    l = [0] * n_seconds
    # Single pass over the packets instead of rescanning the whole trace
    # once per second.
    for pkt in Total:
        stamp = pkt[0]
        if stamp > n_seconds:
            continue
        l[max(int(math.ceil(stamp)) - 1, 0)] += 1

    avg_number_per_sec = sum(l) / float(len(l))
    return avg_number_per_sec, np.std(l), np.percentile(l, 50), min(l), max(l), l

#Variant of packet ordering features from http://cacr.uwaterloo.ca/techreports/2014/cacr2014-05.pdf
def avg_pkt_ordering_stats(trace_data):
    """Return mean and std of packet positions per direction.

    A packet's position is its index in the trace. The result is
    ``(mean_pos_dir1, mean_pos_dir_minus1, std_pos_dir1, std_pos_dir_minus1)``.
    The first and third values describe direction ``1`` packets, the second
    and fourth direction ``-1`` packets. This output order is kept unchanged
    for existing consumers of the feature vector. A direction with no
    packets contributes zeros instead of dividing by zero.
    """
    Total = get_pkt_list(trace_data)
    positions_dir1 = [idx for idx, pkt in enumerate(Total) if pkt[1] == 1]
    positions_dir_minus1 = [idx for idx, pkt in enumerate(Total) if pkt[1] == -1]

    def _mean_std(positions):
        # Zero-fill an absent direction, as the other feature helpers do.
        if not positions:
            return 0, 0
        return sum(positions) / float(len(positions)), np.std(positions)

    mean_first, std_first = _mean_std(positions_dir1)
    mean_second, std_second = _mean_std(positions_dir_minus1)
    return mean_first, mean_second, std_first, std_second

def perc_inc_out(trace_data):
    """Return the fractions of incoming and outgoing packets in a trace.

    Both values are ratios in relation to the total packet count (not scaled
    to 100), returned as ``(fraction_in, fraction_out)``.
    """
    packets = get_pkt_list(trace_data)
    incoming, outgoing = In_Out(packets)
    total = float(len(packets))
    return len(incoming) / total, len(outgoing) / total

############### SIZE FEATURES #####################

def total_size(list_data):
    """Return the sum of the second field over all entries of *list_data*."""
    running = 0
    for entry in list_data:
        running += entry[1]
    return running

def in_out_size(list_data):
    """Return ``(size_in, size_out)``: summed second field per direction.

    Directions are separated with ``In_Out``, which keeps only entries whose
    second field is exactly ``-1`` or ``1``.
    """
    incoming, outgoing = In_Out(list_data)
    size_in = sum(pkt[1] for pkt in incoming)
    size_out = sum(pkt[1] for pkt in outgoing)
    return size_in, size_out

def average_total_pkt_size(list_data):
    """Return the NumPy mean of the second field over all entries."""
    sizes = []
    for entry in list_data:
        sizes.append(entry[1])
    return np.mean(sizes)

def average_in_out_pkt_size(list_data):
    """Return ``(mean_in, mean_out)`` of the second field per direction.

    The direction split is delegated to ``In_Out``; each mean is computed
    with ``np.mean``.
    """
    return tuple(
        np.mean([pkt[1] for pkt in group]) for group in In_Out(list_data)
    )

def variance_total_pkt_size(list_data):
    """Return the population variance (``np.var``) of the second field."""
    sizes = [entry[1] for entry in list_data]
    variance = np.var(sizes)
    return variance

def variance_in_out_pkt_size(list_data):
    """Return ``(var_in, var_out)``: ``np.var`` of the second field per direction."""
    incoming, outgoing = In_Out(list_data)
    in_sizes = [pkt[1] for pkt in incoming]
    var_size_in = np.var(in_sizes)
    out_sizes = [pkt[1] for pkt in outgoing]
    var_size_out = np.var(out_sizes)
    return var_size_in, var_size_out

def std_total_pkt_size(list_data):
    """Return the population standard deviation (``np.std``) of the second field."""
    sizes = [entry[1] for entry in list_data]
    deviation = np.std(sizes)
    return deviation

def std_in_out_pkt_size(list_data):
    """Return ``(std_in, std_out)``: ``np.std`` of the second field per direction."""
    return tuple(
        np.std([pkt[1] for pkt in group]) for group in In_Out(list_data)
    )

def max_in_out_pkt_size(list_data):
    """Return ``(max_in, max_out)``: largest second field per direction.

    Raises ``ValueError`` (from ``max``) when a direction has no entries.
    """
    incoming, outgoing = In_Out(list_data)
    in_sizes = [pkt[1] for pkt in incoming]
    max_size_in = max(in_sizes)
    out_sizes = [pkt[1] for pkt in outgoing]
    max_size_out = max(out_sizes)
    return max_size_in, max_size_out

def unique_pkt_lengths(list_data):
    """Placeholder: not implemented. Ignores *list_data* and returns ``None``."""
    pass

############### FEATURE FUNCTION #####################
def get_hostfts(trace_data):
    """Extract the 40 host features from a trace's ``HOST_FTS`` line.

    A matching line is comma-separated. The first field is discarded and the
    rest are parsed: the literal string ``'0'`` becomes ``int`` 0, everything
    else a ``float``. If several lines match, the last one wins. Without any
    matching line, 40 zeros are returned.

    Raises:
        AssertionError: if a matching line does not carry exactly 40 values.
    """
    hostfts = [0] * 40
    for line in trace_data:
        if "HOST_FTS" not in line:
            continue
        raw_values = line.rstrip().split(",")[1:]
        hostfts = [int(value) if value == '0' else float(value) for value in raw_values]
        assert len(hostfts) == 40
    return hostfts

def get_ft_labels(ALL_FEATURES):
    """Append every feature group to *ALL_FEATURES* and print index bookkeeping.

    The list is mutated in place and nothing is returned.

    NOTE(review): the feature names used here (``intertimestats``,
    ``timestats``, ``number_pkts``, ... ``per_sec``) are neither parameters
    nor locals of this function, so they must exist as module-level globals
    when it is called; otherwise the first use raises ``NameError``. Confirm
    where they are expected to come from.

    NOTE(review): from the "Number of pkts" step onward, ``next`` is read
    *before* the group is added, so each printed ``prev, next`` pair
    describes the span of the previously added group, not the one named in
    the message. The local ``next`` also shadows the builtin.
    """
    prev = 0
    next = 0
    # TIME Features
    ALL_FEATURES.extend(intertimestats)
    prev = len(ALL_FEATURES)
    print("Inter packet time stats: ", prev) #0-11

    ALL_FEATURES.extend(timestats)
    next = len(ALL_FEATURES)
    print("Time stats: ", prev, next) #12-23

    # From here on the length is sampled before the extend/append below it.
    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(number_pkts)
    print("Number of pkts: ", prev, next) #24-26

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(thirtypkts)
    print("Thirty packets stats: ", prev, next) #27-30

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(stdconc)
    print("Std pkt conc: ", prev, next) #31

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(avgconc) #32
    print("Avg pkt conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(avg_per_sec)
    print("Avg per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(std_per_sec)
    print("Std per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(avg_order_in)
    print("avg order in: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(avg_order_out)
    print("avg order out: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(std_order_in)
    print("Std order in: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(std_order_out)
    print("std order out: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(medconc)
    print("medconc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(med_per_sec)
    print("med per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(min_per_sec)
    print("min per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(max_per_sec)
    print("max per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(maxconc)
    print("max conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(perc_in)
    print("% in: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(perc_out)
    print("% out : ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(altconc)
    print("alt conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(alt_per_sec)
    print("alt per sec: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(sum(altconc))
    print("sum alt conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(sum(alt_per_sec))
    print("sum alt per conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(sum(intertimestats))
    print("sum inter time stats: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(sum(timestats))
    print("sum time stats: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.append(sum(number_pkts))
    print("sum number of pkts: ", prev, next)

    # Unlike TOTAL_FEATURES, the raw per-window and per-second lists are
    # appended here as well.
    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(conc)
    print("Conc: ", prev, next)

    prev = next
    next = len(ALL_FEATURES)
    ALL_FEATURES.extend(per_sec)
    print("per sec: ", prev, next)
    prev = next
    return

# Function to count features and index mapping
def count_fts(lst_fts):
    """Print a name-to-index report for a list of wrapped feature groups.

    Each element of *lst_fts* is a one-element wrapper: ``[list_of_values]``
    for a multi-valued group or ``[scalar]`` for a single feature. Groups are
    matched to the built-in category names by position. Output goes to
    stdout only and the function returns ``None``.
    """
    fnames = [
        "intertimestats", "timestats", "number_pkts", "thirtypkts",
        "stdconc", "avgconc", "avg_per_sec", "std_per_sec",
        "avg_order_in", "avg_order_out", "std_order_in", "std_order_out",
        "medconc", "med_per_sec", "min_per_sec", "max_per_sec", "maxconc",
        "perc_in", "perc_out", "altconc", "alt_per_sec",
        "sum_altconc", "sum_alt_per_sec", "sum_intertimestats",
        "sum_timestats", "sum_number_pkts",
    ]
    print("Total categories: ", len(fnames))
    print(lst_fts)
    print(len(lst_fts), len(fnames))

    ind = 0
    fts = []
    for position, entry in enumerate(lst_fts):
        value = entry[0]
        if not isinstance(value, list):
            # Scalar feature: occupies the next single index.
            ind += 1
            print(fnames[position], ":", ind, ": 1")
            fts.append(value)
            continue

        width = len(value)
        if ind == 0:
            # While the running index is still 0 the group starts at index 0.
            start = ind
            ind += width - 1
        else:
            start = ind + 1
            ind += width
        print(fnames[position], ":", start, "-", ind, ": ", width)
        fts.extend(value)

    print("Feat list len: ", len(fts))
    return

# If size information is available, add the size features to the function below.
def TOTAL_FEATURES(trace_data, hostfts, onlyhost=False, max_size=175): #175
    """Build the feature vector for one trace.

    Args:
        trace_data: sequence of text lines of the trace.
        hostfts: truthiness flag; when true the 40 host features read from
            the trace are appended after the traffic features.
        onlyhost: when true, return only the 40 host features.
        max_size: length the traffic-feature part is zero-padded or
            truncated to.

    Returns:
        A list of ``max_size`` traffic features, followed by 40 host
        features when *hostfts* is truthy; or just the 40 host features when
        *onlyhost* is set.
    """
    if onlyhost:
        features = get_hostfts(trace_data)
        #print("HOST FTS: ", features)
        assert len(features) == 40
        return features

    list_data = get_pkt_list(trace_data)

    intertimestats = list(interarrival_maxminmeansd_stats(list_data)[0])
    timestats = time_percentile_stats(trace_data)
    number_pkts = list(number_pkt_stats(trace_data))
    thirtypkts = first_and_last_30_pkts_stats(trace_data)
    stdconc, avgconc, medconc, _minconc, maxconc, conc = pkt_concentration_stats(trace_data)
    avg_per_sec, std_per_sec, med_per_sec, min_per_sec, max_per_sec, per_sec = number_per_sec(trace_data)
    avg_order_in, avg_order_out, std_order_in, std_order_out = avg_pkt_ordering_stats(trace_data)
    perc_in, perc_out = perc_inc_out(trace_data)

    altconc = [sum(x) for x in chunkIt(conc, 70)]
    alt_per_sec = [sum(x) for x in chunkIt(per_sec, 20)]
    # chunkIt's float stepping can return either num or num + 1 slices; pad
    # the shorter outcome so both groups keep a fixed width.
    if len(altconc) == 70:
        altconc.append(0)
    if len(alt_per_sec) == 20:
        alt_per_sec.append(0)

    ##count_fts([[intertimestats], [timestats], [number_pkts],[thirtypkts],[stdconc],[avgconc],[avg_per_sec],[std_per_sec],[avg_order_in],[avg_order_out],[std_order_in],[std_order_out],[medconc],[med_per_sec],[min_per_sec],[max_per_sec],[maxconc],[perc_in],[perc_out],[altconc],[alt_per_sec],[sum(altconc)],[sum(alt_per_sec)],[sum(intertimestats)],[sum(timestats)],[sum(number_pkts)]])

    # The order below defines the feature indices; do not reorder.
    ALL_FEATURES = []
    ALL_FEATURES.extend(intertimestats)
    ALL_FEATURES.extend(timestats)
    ALL_FEATURES.extend(number_pkts)
    ALL_FEATURES.extend(thirtypkts)
    ALL_FEATURES.extend([
        stdconc,
        avgconc,
        avg_per_sec,
        std_per_sec,
        avg_order_in,
        avg_order_out,
        std_order_in,
        std_order_out,
        medconc,
        med_per_sec,
        min_per_sec,
        max_per_sec,
        maxconc,
        perc_in,
        perc_out,
    ])
    ALL_FEATURES.extend(altconc)
    ALL_FEATURES.extend(alt_per_sec)
    ALL_FEATURES.extend([
        sum(altconc),
        sum(alt_per_sec),
        sum(intertimestats),
        sum(timestats),
        sum(number_pkts),
    ])

    print("WF features: ", len(ALL_FEATURES))
    # Zero-pad up to max_size; longer vectors are truncated below.
    ALL_FEATURES.extend([0] * (max_size - len(ALL_FEATURES)))

    if hostfts:
        wffts = ALL_FEATURES[:max_size]
        hostfts = get_hostfts(trace_data)
        features = wffts + hostfts
        print(hostfts)
        # max_size traffic features plus 40 host features (215 by default).
        assert len(features) == max_size + 40
    else:
        features = ALL_FEATURES[:max_size]

    return features

def chunks(l, n):
    """Yield successive n-sized chunks from l.

    The last chunk may be shorter than *n*. An empty *l* yields nothing.
    """
    # range() instead of the Python-2-only xrange(), which is a NameError
    # on Python 3.
    for i in range(0, len(l), n):
        yield l[i:i+n]

def checkequal(lst):
    """Return True when all elements of *lst* are equal.

    Compares the sequence without its first element against the sequence
    without its last element. Empty and single-element inputs give True.
    """
    without_first = lst[1:]
    without_last = lst[:-1]
    return without_first == without_last


